Add Mixtral #2196
base: master
Conversation
Left a few comments! Please provide a demo Colab.
)
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.EinsumDense(
Update the layer names to be compatible with `enable_lora`.
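For reference, a minimal sketch of what the rename could look like; the einsum equation, output shape, and layer names below are illustrative assumptions, not taken from the PR:

```python
import keras

# Sketch only, not the PR's code: keras_hub's Backbone.enable_lora
# targets the attention projections by layer name, so the EinsumDense
# layers should carry the names it looks for (assumed here to be
# "query" / "key" / "value").
query_dense = keras.layers.EinsumDense(
    equation="bqm,muh->bquh",    # illustrative equation
    output_shape=(None, 8, 64),  # illustrative (seq, num_heads, head_dim)
    name="query",
)
```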
@keras_hub_export("keras_hub.models.MixtralBackbone")
class MixtralBackbone(Backbone):
    """
    The Mixtral Transformer core architecture with hyperparameters.
The docstring's first line should follow the opening `"""` on the same line.
This still needs to be changed to --> """The Mixtral Transformer core architecture with hyperparameters.
preprocessor("League of legends")

# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
Why `tf`?
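A backend-agnostic version of the snippet above could simply pass a plain Python list; an illustrative rewrite, not the PR's code:

```python
# Tokenize a batch of sentences without a backend-specific constant;
# a plain Python list (or a NumPy array) works across backends.
sentences = ["Taco tuesday", "Fish taco please!"]
preprocessor(sentences)
```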
target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)

embeddings = None
with tf.GradientTape(watch_accessed_variables=True) as tape:
Why `tf`?
We don't recommend backend-specific examples. For generic usage, use `keras.ops` or NumPy.
There are some conflicts in the `api` directory due to the recent changes; please resolve them.
Conflicts resolved.
We don't recommend backend-specific examples. For generic usage, use `keras.ops` or NumPy.
@sachinprasadhs As I mentioned above, there are already `tf.GradientTape` examples in existing model docstrings; those should be cleaned up in a separate PR.
Let's not pile on the mess in new PRs. Let's keep it clean.
Added a few more comments.
from keras import ops


# TODO: Deprecate this in favor of
We don't support Keras 2 anymore in KerasHub, so I guess you can get rid of this.
Forgot to remove this comment. No, Keras `LayerNormalization` doesn't produce the same results as this custom layer norm.
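For context, a minimal sketch of why the two differ, assuming the custom layer is an RMS-style norm as in reference Mixtral implementations (the class name below is hypothetical):

```python
import keras
from keras import ops


# Sketch only, not the PR's code: an RMS-style norm scales by the
# root-mean-square of the features without mean-centering, whereas
# keras.layers.LayerNormalization also subtracts the mean, so the two
# do not match numerically.
class MixtralLayerNorm(keras.layers.Layer):  # hypothetical name
    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Learnable per-feature scale, initialized to ones.
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones"
        )
        self.built = True

    def call(self, x):
        # Normalize by the RMS over the last axis, then rescale.
        variance = ops.mean(ops.square(x), axis=-1, keepdims=True)
        return x * ops.rsqrt(variance + self.epsilon) * self.scale
```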
# Below is a workaround for `ops.triu` for Keras 2.
# TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is
# removed.
# causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
Keras 2 support is removed now, so you can enable this.
`ops.triu`/`ops.tril` have issues with dynamic shapes on the TensorFlow backend
(see `_mask_sliding_window` in `keras_hub/src/models/gemma/gemma_attention.py`),
hence I chose to keep this as it is.
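A sketch of the kind of index-comparison workaround being referenced, under the assumption that the goal is a sliding-window causal mask (the function name and signature are illustrative, not the PR's code):

```python
from keras import ops


# Sketch only: a banded causal mask built from index comparisons instead
# of `ops.triu`/`ops.tril`, which avoids the dynamic-shape issue on the
# TensorFlow backend.
def sliding_window_causal_mask(query_len, key_len, sliding_window):
    q_idx = ops.arange(query_len)[:, None]
    k_idx = ops.arange(key_len)[None, :]
    causal = k_idx <= q_idx                      # no attention to future tokens
    in_window = k_idx > q_idx - sliding_window   # only the last `sliding_window` keys
    return ops.logical_and(causal, in_window)
```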
Updated the comment, though!
Okay, can you remove the line "# Below is a workaround for `ops.triu` for Keras 2."?
Thanks! Left some small change comments.
`tf.RaggedTensor` where the last dimension of the output is ragged.

If input is a scalar string (rank == 0), the layer will output a dense
`tf.Tensor` with static shape `[None]`.
This needs to be corrected, since this is not specific to the TF backend.
# Below is a workaround for `ops.triu` for Keras 2.
# TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is
# removed.
# causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
Okay, can you remove the line "# Below is a workaround for `ops.triu` for Keras 2."?
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 16),
run_quantization_check=False,
Can you enable this test?
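A sketch of what that could look like, assuming the snippet above comes from a `run_backbone_test` call in the backbone's test case and that the quantization check is on by default when the flag is not passed (both assumptions, not taken from the PR):

```python
# Sketch only, not the PR's code: dropping the `run_quantization_check`
# override re-enables the check.
self.run_backbone_test(
    cls=MixtralBackbone,  # assumed class under test
    init_kwargs=self.init_kwargs,
    input_data=self.input_data,
    expected_output_shape=(2, 5, 16),
)
```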
@keras_hub_export("keras_hub.models.MixtralBackbone")
class MixtralBackbone(Backbone):
    """
    The Mixtral Transformer core architecture with hyperparameters.
This still needs to be changed to --> """The Mixtral Transformer core architecture with hyperparameters.
What about the `aux_loss` (router load-balancing loss) implementation for Mixtral?
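For reference, a minimal sketch of the Switch-Transformer-style load-balancing loss typically used for Mixtral-like routers; the function name and tensor layout are assumptions, not the PR's code:

```python
from keras import ops


# Sketch only: `router_logits` is assumed to be (num_tokens, num_experts)
# and `top_k` the number of experts selected per token.
def load_balancing_loss(router_logits, num_experts, top_k):
    router_probs = ops.softmax(router_logits, axis=-1)
    # One-hot indicator of the experts each token was routed to.
    _, selected_experts = ops.top_k(router_logits, k=top_k)
    expert_mask = ops.one_hot(selected_experts, num_experts)
    # Fraction of tokens dispatched to each expert.
    tokens_per_expert = ops.mean(ops.max(expert_mask, axis=1), axis=0)
    # Mean router probability assigned to each expert.
    prob_per_expert = ops.mean(router_probs, axis=0)
    return num_experts * ops.sum(tokens_per_expert * prob_per_expert)
```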
This PR adds Mixtral to Keras Hub.
Reference
Mixtral output matching